//! HTTP Loadbalance Layer
//!
//! This is a copy of `volo::loadbalance::layer` without the retry logic. Retrying requires the
//! request type `Req` to implement `Clone`, but an HTTP body may be a stream, which cannot be
//! cloned, so the retry-related code is removed here.
//!
//! In addition, an HTTP service can use DNS for service discovery, so the default load balancer
//! uses a DNS resolver to pick a target address (the DNS resolver picks only one address because
//! it does not need load balancing).

use std::{fmt::Debug, sync::Arc};

use async_broadcast::RecvError;
use motore::{layer::Layer, service::Service};
use volo::{
    context::Context,
    discovery::Discover,
    loadbalance::{LoadBalance, MkLbLayer, random::WeightedRandomBalance},
};

use super::dns::{DiscoverKey, DnsResolver};
use crate::{
    context::ClientContext,
    error::{
        ClientError,
        client::{lb_error, no_available_endpoint},
    },
    request::Request,
};

/// Default load balance configuration: a [`WeightedRandomBalance`] keyed by [`DiscoverKey`],
/// with [`DnsResolver`] as the [`Discover`] implementation.
pub type DefaultLb = LbConfig<WeightedRandomBalance<DiscoverKey>, DnsResolver>;
/// Default load balance [`Service`] obtained from [`DefaultLb`] (via its [`LoadBalanceLayer`]),
/// wrapping the inner service `S`.
pub type DefaultLbService<S> =
    LoadBalanceService<WeightedRandomBalance<DiscoverKey>, DnsResolver, S>;

/// Load balance layer generator with a [`LoadBalance`] and a [`Discover`]
///
/// Turned into a [`LoadBalanceLayer`] through [`MkLbLayer::make`].
pub struct LbConfig<L, D> {
    /// The load balancer used to pick a target address for each request.
    load_balance: L,
    /// The [`Discover`] the load balancer consults (and watches, if supported) for instances.
    discover: D,
}

impl Default for DefaultLb {
    /// Weighted random balancing over the instances found by a default [`DnsResolver`].
    fn default() -> Self {
        Self {
            load_balance: WeightedRandomBalance::new(),
            discover: DnsResolver::default(),
        }
    }
}

impl<L, D> LbConfig<L, D> {
    /// Create a new [`LbConfig`] using a [`LoadBalance`] and a [`Discover`]
    pub fn new(load_balance: L, discover: D) -> Self {
        LbConfig {
            load_balance,
            discover,
        }
    }

    /// Set a [`LoadBalance`] to the [`LbConfig`] and replace the previous one
    pub fn load_balance<NL>(self, load_balance: NL) -> LbConfig<NL, D> {
        LbConfig {
            load_balance,
            discover: self.discover,
        }
    }

    /// Set a [`Discover`] to the [`LbConfig`] and replace the previous one
    pub fn discover<ND>(self, discover: ND) -> LbConfig<L, ND> {
        LbConfig {
            load_balance: self.load_balance,
            discover,
        }
    }
}

impl<LB, D> MkLbLayer for LbConfig<LB, D> {
    type Layer = LoadBalanceLayer<LB, D>;

    /// Consumes the configuration, moving its load balancer and discover into a
    /// [`LoadBalanceLayer`].
    fn make(self) -> Self::Layer {
        let Self {
            load_balance,
            discover,
        } = self;
        LoadBalanceLayer::new(load_balance, discover)
    }
}

/// [`Layer`] for load balance, produced by [`LbConfig`].
///
/// Holds the load balancer and discover until [`Layer::layer`] wraps an inner service into a
/// [`LoadBalanceService`].
#[derive(Clone, Copy, Default)]
pub struct LoadBalanceLayer<LB, D> {
    load_balance: LB,
    discover: D,
}

impl<LB, D> LoadBalanceLayer<LB, D> {
    /// Stores `load_balance` and `discover` as-is.
    fn new(load_balance: LB, discover: D) -> Self {
        Self {
            load_balance,
            discover,
        }
    }
}

impl<LB, D, S> Layer<S> for LoadBalanceLayer<LB, D>
where
    LB: LoadBalance<D>,
    D: Discover,
{
    type Service = LoadBalanceService<LB, D, S>;

    /// Wraps `inner` in a [`LoadBalanceService`] that owns this layer's load balancer and
    /// discover.
    fn layer(self, inner: S) -> Self::Service {
        let Self {
            load_balance,
            discover,
        } = self;
        LoadBalanceService::new(load_balance, discover, inner)
    }
}

/// [`Service`] for load balance generated by [`LoadBalanceLayer`]
///
/// If the callee has no address yet, the service asks the [`LoadBalance`] for a picker over the
/// [`Discover`], sets the first picked address on the callee, and then forwards the request to
/// the inner service. Cloning the service shares the same load balancer through the [`Arc`].
#[derive(Clone)]
pub struct LoadBalanceService<LB, D, S> {
    /// Load balancer, shared with the background rebalancing task (if one was spawned).
    load_balance: Arc<LB>,
    /// Discover passed to the load balancer each time a picker is requested.
    discover: D,
    /// The wrapped inner service.
    service: S,
}

impl<LB, D, S> LoadBalanceService<LB, D, S>
where
    LB: LoadBalance<D>,
    D: Discover,
{
    fn new(load_balance: LB, discover: D, service: S) -> Self {
        let lb = Arc::new(load_balance);

        let service = Self {
            load_balance: lb.clone(),
            discover,
            service,
        };

        let Some(mut channel) = service.discover.watch(None) else {
            return service;
        };

        tokio::spawn(async move {
            loop {
                match channel.recv().await {
                    Ok(recv) => lb.rebalance(recv),
                    Err(err) => match err {
                        RecvError::Closed => break,
                        _ => tracing::warn!("[Volo-HTTP] discovering subscription error: {err}"),
                    },
                }
            }
        });

        service
    }
}

impl<LB, D, S, B> Service<ClientContext, Request<B>> for LoadBalanceService<LB, D, S>
where
    LB: LoadBalance<D>,
    D: Discover,
    S: Service<ClientContext, Request<B>, Error = ClientError> + Send + Sync,
    B: Send,
{
    type Response = S::Response;
    type Error = S::Error;

    /// Ensures the callee has a target address, then forwards `req` to the inner service.
    ///
    /// If the callee already carries an address, the load balancer is bypassed entirely.
    /// Otherwise the first address yielded by the load balancer's picker is set on the callee.
    ///
    /// # Errors
    ///
    /// - the error built by `lb_error` if the load balancer fails to produce a picker;
    /// - the error built by `no_available_endpoint` if the picker yields no address;
    /// - otherwise, whatever the inner service returns.
    async fn call(
        &self,
        cx: &mut ClientContext,
        req: Request<B>,
    ) -> Result<Self::Response, Self::Error> {
        let callee = cx.rpc_info().callee();

        // An address was already chosen upstream; there is nothing to balance.
        if callee.address.is_some() {
            return self.service.call(cx, req).await;
        }

        let mut candidates = self
            .load_balance
            .get_picker(callee, &self.discover)
            .await
            .map_err(lb_error)?;
        let target = candidates.next().ok_or_else(no_available_endpoint)?;
        cx.rpc_info_mut().callee_mut().set_address(target);

        self.service.call(cx, req).await
    }
}

impl<LB, D, S> Debug for LoadBalanceService<LB, D, S>
where
    LB: Debug,
    D: Debug,
{
    /// Formats the load balancer and the discover.
    ///
    /// The inner `service` is not printed, so `S` is not required to implement [`Debug`]. The
    /// omission is made visible with a trailing `..` via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LBService")
            .field("load_balancer", &self.load_balance)
            .field("discover", &self.discover)
            .finish_non_exhaustive()
    }
}
